
from __future__ import annotations

import argparse
import json
import os
from pathlib import Path
from typing import Dict, List

# ---------------------------------------------------------------------------- #
# Emotion vocabulary & mapping helpers                                         #
# ---------------------------------------------------------------------------- #

# Maps each canonical emotion label to a comma-separated string of related
# descriptor words. The keys are the same labels, in the same order, as
# EMOTION_ORDER below.
# NOTE(review): no function in the visible part of this file reads this
# mapping -- presumably it is reference vocabulary or is used elsewhere;
# confirm before removing.
EMOTION_VOCAB: Dict[str, str] = {
    "active": "energetic, adventurous, vibrant, enthusiastic, playful",
    "afraid": "horrified, scared, fearful",
    "alarmed": "concerned, worried, anxious, overwhelmed",
    "alert": "attentive, curious",
    "amazed": "surprised, astonished, awed, fascinated, intrigued",
    "amused": "humored, laughing",
    "angry": "annoyed, irritated",
    "calm": "soothed, peaceful, comforted, fulfilled, cozy",
    "cheerful": "delighted, happy, joyful, carefree, optimistic",
    "confident": "assured, strong, healthy",
    "conscious": "aware, thoughtful, prepared",
    "creative": "inventive, productive",
    "disturbed": "disgusted, shocked",
    "eager": "hungry, thirsty, passionate",
    "educated": "informed, enlightened, smart, savvy, intelligent",
    "emotional": "vulnerable, moved, nostalgic, reminiscent",
    "empathetic": "sympathetic, supportive, understanding, receptive",
    "fashionable": "trendy, elegant, beautiful, attractive, sexy",
    "feminine": "womanly, girlish",
    "grateful": "thankful",
    "inspired": "motivated, ambitious, empowered, determined",
    "jealous": "jealous",
    "loving": "loved, romantic",
    "manly": "manly",
    "persuaded": "impressed, enchanted, immersed",
    "pessimistic": "skeptical",
    "proud": "patriotic",
    "sad": "depressed, upset, betrayed, distant",
    "thrifty": "frugal",
    "youthful": "childlike",
}

# Canonical emotion labels in annotation.json order. The 1-based position of a
# label in this list is its option number in the ground-truth file, so entries
# must never be reordered or inserted in the middle.
EMOTION_ORDER: List[str] = [
    "active", "afraid", "alarmed", "alert", "amazed",  # 1-5
    "amused", "angry", "calm", "cheerful", "confident",  # 6-10
    "conscious", "creative", "disturbed", "eager", "educated",  # 11-15
    "emotional", "empathetic", "fashionable", "feminine", "grateful",  # 16-20
    "inspired", "jealous", "loving", "manly", "persuaded",  # 21-25
    "pessimistic", "proud", "sad", "thrifty", "youthful",  # 26-30
]

# Lookup table: 1-based option number -> emotion label.
OPTION_TO_EMOTION: Dict[int, str] = dict(enumerate(EMOTION_ORDER, start=1))


# ---------------------------------------------------------------------------- #
# Utility functions                                                             #
# ---------------------------------------------------------------------------- #


def load_ground_truth(path: Path) -> Dict[str, str]:
    """Load ground-truth mapping *video_id* -> *emotion string*.

    Parameters
    ----------
    path : Path
        Path to *annotation.json* file. The file must contain a JSON object
        whose values are 1-based option numbers (integers, or strings that
        parse as integers).

    Returns
    -------
    Dict[str, str]
        Dictionary mapping video id to canonical emotion key.

    Raises
    ------
    ValueError
        If the top-level JSON value is not an object, or an option value
        cannot be interpreted as an integer (including booleans and
        non-integral / non-finite floats).
    KeyError
        If an option number has no entry in ``OPTION_TO_EMOTION``.
    """
    with path.open("r", encoding="utf-8") as f:
        raw = json.load(f)

    if not isinstance(raw, dict):
        raise ValueError(
            f"Ground truth file {path} must contain a JSON object, got {type(raw).__name__}."
        )

    gt: Dict[str, str] = {}
    for vid, opt in raw.items():
        invalid_msg = f"Invalid option value {opt!r} for video {vid!r} in ground truth."

        if isinstance(opt, bool):
            # bool is a subclass of int; JSON true/false is never a valid option.
            raise ValueError(invalid_msg)
        if isinstance(opt, int):
            opt_int = opt
        elif isinstance(opt, float):
            # Accept 3.0 but refuse 3.7 / NaN / Infinity instead of truncating
            # silently (or crashing with OverflowError on infinity).
            if not opt.is_integer():
                raise ValueError(invalid_msg)
            opt_int = int(opt)
        else:
            # Sometimes numbers are encoded as strings; try to cast.
            try:
                opt_int = int(opt)
            except (TypeError, ValueError) as err:
                raise ValueError(invalid_msg) from err

        emo = OPTION_TO_EMOTION.get(opt_int)
        if emo is None:
            raise KeyError(f"No emotion mapping found for option {opt_int} (video id: {vid}).")
        gt[vid] = emo
    return gt


def load_predictions(pred_dir: Path) -> List[tuple[str, str]]:
    """Load predictions from **all** *.json* files inside *pred_dir*.

    Files are visited in sorted path order so the returned list (and any
    "last one wins" deduplication built on top of it) is deterministic.

    Parameters
    ----------
    pred_dir : Path
        Directory whose ``*.json`` files are read. Each file may hold either a
        single prediction object or a list of prediction objects.

    Returns
    -------
    List[tuple[str, str]]
        *(video_id, emotion)* pairs **including duplicates** so that accuracy
        can be computed both with and without deduplication. Files that cannot
        be read or parsed, non-object records, and records without an emotion
        are skipped with a ``[WARN]`` message on stdout.
    """
    records_all: List[tuple[str, str]] = []

    # glob() yields entries in arbitrary filesystem order; sort for stable output.
    for json_path in sorted(pred_dir.glob("*.json")):
        try:
            with json_path.open("r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError;
            # OSError covers unreadable files. One bad file must not abort the run.
            print(f"[WARN] Failed to parse {json_path.name}: {e}")
            continue

        # Some prediction files are a single object, others are a list of objects.
        recs_in_file = data if isinstance(data, list) else [data]

        for rec in recs_in_file:
            if not isinstance(rec, dict):
                print(f"[WARN] Unexpected record type in {json_path.name}: {type(rec).__name__}; skipping.")
                continue

            # Fall back to the file name (without extension) when the id is absent.
            video_id = rec.get("video_id") or json_path.stem

            # Accept either key name.
            emotion = rec.get("final_topic") or rec.get("predicted_topic")
            if not emotion:
                print(f"[WARN] Missing emotion prediction for video {video_id} in {json_path.name}; skipping.")
                continue

            records_all.append((video_id, emotion))

    return records_all


# ------------------------------------------------------------------------- #
# Accuracy helpers
# ------------------------------------------------------------------------- #


def compute_accuracy_records(records: List[tuple[str, str]], gt: Dict[str, str]) -> tuple[int, int]:
    """Return ``(correct, total)`` over every prediction record, duplicates included.

    A record is correct when its emotion equals the ground-truth emotion for
    its video id. ``total`` is simply the number of records, so records whose
    id is missing from *gt* still count towards the denominator.
    """
    hits = sum(1 for video_id, predicted in records if gt.get(video_id) == predicted)
    return hits, len(records)


def compute_accuracy_unique(pred_unique: Dict[str, str], gt: Dict[str, str]) -> tuple[int, int]:
    """Return ``(correct, total)`` for a de-duplicated video_id -> emotion mapping.

    Only video ids that have a ground-truth entry are scored; predictions for
    ids unknown to *gt* are left out of both counts.
    """
    # One boolean per scored prediction: True when it matches the ground truth.
    outcomes = [
        predicted == gt[video_id]
        for video_id, predicted in pred_unique.items()
        if gt.get(video_id) is not None
    ]
    return sum(outcomes), len(outcomes)


def _load_predictions_file(path: Path) -> List[tuple[str, str]]:
    """Load *(video_id, emotion)* pairs from a single prediction JSON file.

    The file may hold one prediction object or a list of them. Non-object
    records and records without an emotion are skipped silently. An unreadable
    or unparsable file produces a ``[WARN]`` message and an empty list.
    """
    try:
        with path.open("r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"[WARN] Could not read {path.name}: {e}")
        return []

    recs = data if isinstance(data, list) else [data]
    pairs: List[tuple[str, str]] = []
    for rec in recs:
        if not isinstance(rec, dict):
            continue
        # Fall back to the file name (without extension) when the id is absent.
        vid = rec.get("video_id") or path.stem
        # Accept either key name for the predicted emotion.
        emo = rec.get("final_topic") or rec.get("predicted_topic")
        if not emo:
            continue
        pairs.append((vid, emo))
    return pairs


def main() -> None:
    """Command-line entry point: score every prediction file in a directory.

    For each ``*.json`` file in ``--pred_dir`` (in sorted order) two accuracies
    are computed against the ground truth in ``--annot_file``: one over all
    records including duplicates ("dup") and one after keeping only the last
    record per video id ("unique"). One line per file is printed and written
    to ``--output``. An accuracy with a zero denominator is reported as 0.0.

    Raises
    ------
    NotADirectoryError
        If ``--pred_dir`` is not an existing directory.
    FileNotFoundError
        If ``--annot_file`` is not an existing file.
    """
    parser = argparse.ArgumentParser(description="Evaluate emotion prediction accuracy.")
    parser.add_argument("--pred_dir", type=str, required=True, help="Directory containing prediction JSON files.")
    parser.add_argument("--annot_file", type=str, required=True, help="Path to emotion_annotation.json ground-truth file.")
    # %(default)s keeps the help text in sync with the actual default value.
    parser.add_argument("--output", type=str, default="np_metric.txt", help="File to write accuracy to (default: %(default)s)")
    args = parser.parse_args()

    pred_dir = Path(args.pred_dir)
    annot_path = Path(args.annot_file)
    out_path = Path(args.output)

    if not pred_dir.is_dir():
        raise NotADirectoryError(f"Prediction directory not found: {pred_dir}")
    if not annot_path.is_file():
        raise FileNotFoundError(f"Annotation file not found: {annot_path}")

    # Load ground truth once
    gt_map = load_ground_truth(annot_path)

    # Evaluate each JSON file separately
    with out_path.open("w", encoding="utf-8") as fout:
        for json_path in sorted(pred_dir.glob("*.json")):
            pred_records = _load_predictions_file(json_path)

            # Accuracy over all records, duplicates included.
            c_dup, t_dup = compute_accuracy_records(pred_records, gt_map)
            acc_dup = (c_dup / t_dup) if t_dup else 0.0

            # Accuracy over unique video ids; dict() keeps the last pair per id.
            pred_unique: Dict[str, str] = dict(pred_records)

            c_u, t_u = compute_accuracy_unique(pred_unique, gt_map)
            acc_u = (c_u / t_u) if t_u else 0.0

            line = f"{json_path.name}: dup {acc_dup:.4f} | unique {acc_u:.4f}"
            print(line)
            fout.write(line + "\n")


# Run the evaluation only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
